{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7f66a3af1f30>"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "# Short tensor reprs (2 edge items, 75-char lines) and a fixed seed for reproducibility.\n",
    "torch.set_printoptions(linewidth=75, edgeitems=2)\n",
    "torch.manual_seed(123)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ../data/p1ch7/cifar-10-python.tar.gz\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "753331e0836043b69d0601111e96b149",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting ../data/p1ch7/cifar-10-python.tar.gz to ../data/p1ch7/\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "from torchvision import datasets\n",
    "\n",
    "# Fetch CIFAR-10 into data_path (skipped when already present) and open both splits.\n",
    "data_path = '../data/p1ch7/'\n",
    "cifar10 = datasets.CIFAR10(root=data_path, train=True, download=True)\n",
    "cifar10_val = datasets.CIFAR10(root=data_path, train=False, download=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset CIFAR10\n",
       "    Number of datapoints: 50000\n",
       "    Root location: ../data/p1ch7/\n",
       "    Split: Train"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Summary of the training split: number of samples, root directory and split name.\n",
    "cifar10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Human-readable name for each CIFAR-10 label index 0-9.\n",
    "class_names = ('airplane automobile bird cat deer '\n",
    "               'dog frog horse ship truck').split()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(<PIL.Image.Image image mode=RGB size=32x32 at 0x7F66A18AF1D0>,\n",
       " 1,\n",
       " 'automobile')"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Raw samples are (PIL image, integer label) pairs; also look up the label's name.\n",
    "img, label = cifar10[99]\n",
    "img, label, class_names[label]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD5CAYAAADhukOtAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAfCklEQVR4nO2de4xd13Xev3Xf8+Rwhq8RRYmiSIuUZL1Kq0rlGrLSOqobRDbaKHaaQAgMMyhioEadPwQXqB2gfyRFLddNCwd0pEQJXL9tWKgNx6qixDH80MsUSYmSTJE0n5ohOe+Z+76rf9yrgpL3t2dIztxhtL8fMJiZve4+Z519zzrn3v2dtba5O4QQb38yq+2AEKI7KNiFSAQFuxCJoGAXIhEU7EIkgoJdiETIXU5nM7sPwOcAZAH8ubv/cez1+XzOS6V80NZqNWk/b7WYA7RPJnoZ4/1iNvewHxE3EJM2zbKX4AVgkR1mc+HxzWbD7QBQXpiL7I2MPYCeUg+19fX2B9sXFuZpn3q9TG2ZyDHns/w0zuSKwfbe/nA7ADQj52K5xv3P5/hJl89F3utM+BzJZfn2FhbCfSYny5ifrwUH65KD3dpn6v8C8C8BnATwjJk97u4vsT6lUh537N4etM3NTNB9NWrVYHs2zwejtzcStK3IYWe4rVYN+5GPbK5Zr1FbPjdAbRYJ93yBn6hrRzYG24cGN9E+Bw78kNrg3P9dN9xMbXfd+s+C7c+98DTt8/rpg9TWW+QXq6sG1lNb37rrgu233L2N9pmpTlHboaPc/00b+fu5cYTbir3hi8tQ5IK0f18j2P4//vRHtM/lfIy/E8Bhdz/i7jUAXwZw/2VsTwixglxOsG8GcOKC/0922oQQVyCX85099Dnzl75ImNkeAHsAoBj5KCaEWFku585+EsCWC/6/GsDpt77I3fe6+253353P80kKIcTKcjnB/gyAHWZ2nZkVAHwIwOPL45YQYrm55I/x7t4ws48B+Bu0pbdH3f3FaCdzmJEZ7chNP1MoBdtzxci1KqJdmfOdVebD/gFAi8hQsdlxy0Wkt1x4RrVNgVomZ6ap7dzkZLC9XN7H/YjIa3094bEHgLHJ89T2xI//NtjeMi5rzdQq1NYT8WOmwvsNDYYlwJ5iWBUCgC2jfOZ8avqXPrz+f4ZHuB8Dg/ycW6iG5by5BX4OlHrDX4kzGX7iX5bO7u7fBfDdy9mGEKI76Ak6IRJBwS5EIijYhUgEBbsQiaBgFyIRLms2/mJxB+rNsBTVM9BH+1VILkaryaWOZoM/rVetcHmtvz8s1QCA12fC+2JZeQBaxq+nxVxEH8zwTLR8ictQtdlw5lixxGUcGJcA3XgizOnx49SWJ9lB1QUuvRUitU97CtyPaoZvs3YsnFyzUDtF+5SKa6ntqi1XU1tlluaAYWyW+5gthM+DWecZduMT4XO43uDvpe7sQiSCgl2IRFCwC5EICnYhEkHBLkQidHU2PmNAkSSvTM8s0H7m4ZnkWJJGLHFivnzxdeYAoFwLTxf39kdmupt8drS8wGuu1Svcj1ypTm1m4X65SA00j13ziXoCAD15rnjU6+FTK9PkfrScqysLkQSlnh6euFJeCCcGjZ3l+5pbOEFtg8P3Ulupl5f+mqmMUVulHB7jJrgCcW46PB6NJj9vdGcXIhEU7EIkgoJdiERQsAuRCAp2IRJBwS5EInRVemu2WpgniRp1roRgaE1YRquUuVzXjCQETE9zSWNmJpzsAgAjZFWPfq7yYXomIr3NcVkrX+BvzcJ8JHGFSIfu/LpeLfMkjVY9UkMvy2WeYj68TSvx7TW4G23dltCb5bZyeCUknJ3kSSbFYqTe3RSvuzdJ5DAAGD/HbYOD4fcmcgqjPB8+Lm9GlkTjmxNCvJ1QsAuRCAp2IRJBwS5EIijYhUgEBbsQiXBZ0puZHQMwC6AJoOHuu2Ovz5ihUApnPZVK
PINqjix3VI9oNbUaP7Rqldd3Gx7hfgwOhtvHTvPt1Vo8Q61IxgIAIgllyEXGqrIQll4qFe5HqRgZq0jmlbe4NsSS2/KRmnzNekQ2ikiR5RLvNzUf9r/RjNSEW8vH98zYSWqrtXgWYyWiLVfKYamvGclgK1fD/sf6LIfO/l53P7cM2xFCrCD6GC9EIlxusDuA75vZc2a2ZzkcEkKsDJf7Mf5udz9tZhsAPGFmL7v7Dy58QecisAcAisXIusxCiBXlsu7s7n6683scwLcA3Bl4zV533+3uu/OxRdiFECvKJQe7mfWZ2cAbfwN4H4Dw8htCiFXncj7GbwTwLTN7Yzv/292/F+vQagELc2FpIJPlskWOeJnN80KPHpEgtu8aoraBPj4kM+fC8lVzbSTrKpJRlokUgawRaQUAhoZ5v7XrwrLR3Az3sVrmYzW8kS/LVTQuUc3MhSWvOmLLIPHtlSMy60KLj0eDLBHWLHNJcdb4vqo1LjeuHR6mtkjdTix4WLot5vj53WzNBtvdue+XHOzufgTArZfaXwjRXSS9CZEICnYhEkHBLkQiKNiFSAQFuxCJ0N213jLAYG/4+pKNZDXNz4ZlknwuUrCxxGWLFilCCAB149lhXghLVCMkGw4ATp/g+2IyJAA0nfuRK/GxWjsYlq+akfXtCpHt9cbGscX9b5Fss6F1vJhjmdeAxOw0zxqbOBfOigSA/t6w/znSDgDNFj+v6lVum54Oy2FAPNOyRNYlzA/x9+yqzevDfQq8IKbu7EIkgoJdiERQsAuRCAp2IRJBwS5EInR1Nt4B1FrhGcbZMT5buXY4PN3davLln+oWmWHu5UvxzEVmW5u18AxzqcBndgcGuG1NH0/gmJjiM93TE5FZ/GrYxxz4cfVHfKws8LGqkX0BwOBQMdheYFlNAIoRVeP8GJ+Z7unn4zhfDZ8jxYgCUY2dAwtcJelt8nHMFWPJUuEx9kjSUJlIF/VIoo7u7EIkgoJdiERQsAuRCAp2IRJBwS5EIijYhUiErkpvrWYLs3NhyaDZ5DLOPJEmZqa4LFTMc4kkm+W1zrKZyBJEpL1Wi9T9ynNbT4FLPOU6vw67x+TBsCzXihxzZYInmRSy/BTJZ3u4Hx6WvGJjXyvzY85YZImnaX7urB0JS4DlKj93qjU+viNDsUQeLnstVLmtRU6R6Unux+jGtcF256qs7uxCpIKCXYhEULALkQgKdiESQcEuRCIo2IVIhEWlNzN7FMCvAxh395s7bcMAvgJgK4BjAB5w98nFtpXJZDBQCss1Y7N8+aeF8kyw3Z1nO3kzslzQLL/GXbern9oqpNTZ1ByXcTxSp63a4LbSGn5sff0R+Wo6vM2p89zHVpZLPC3jkpGD23qHwmPcynCZbM36Xmq7rsht01NcOmzUiY+R9ZgG1vDzYzBSFw4tHk7HT/MMzeHh8BJbg5FsxFotHC8e0d6Wcmf/SwD3vaXtIQBPuvsOAE92/hdCXMEsGuyd9dYn3tJ8P4DHOn8/BuADy+yXEGKZudTv7Bvd/QwAdH5vWD6XhBArwYo/LmtmewDsAYBCgX8PFUKsLJd6Zx8zs1EA6PweZy90973uvtvdd+fzCnYhVotLDfbHATzY+ftBAN9eHneEECvFUqS3LwG4B8A6MzsJ4FMA/hjAV83sIwCOA/jNpewskzH0kqVuMpG7foYsx1PiCUhYt5Eb123kh91ocolqZi4s59W4qoJGnUuAw1fxrLGhYb7NapVvc5ZkCDYikoxX+TV/03Yu/9Qr3I+shW3ZHO+DDJfycgVu6+vn7+fZ8bDU11eMZPNFikNOz3E/Bvr4WF3VxyXdSSLdDkbk11IpbMtEsjYXDXZ3/zAx/epifYUQVw56gk6IRFCwC5EICnYhEkHBLkQiKNiFSISuFpysVut49cjJsNF4JlepJ3xNWj/KpauRkVj2D894atT4kPT1h2WNniL3/fgvuNRkkWvt3CyXeKbOc1ujTo4tkr1W7OcZZY3I2mHZ
XORe0QxLn1OTXNrM57iGmY+cqtaMZD8S6bNl/ByIqFdoRQpHzhf5eGzdyM+RzEw4a6/ViBUWDR+z+8UXTBVCvM1QsAuRCAp2IRJBwS5EIijYhUgEBbsQidBV6c3d0GqFJYh6ja/NNrI+vF7Xtp3hQn0AMHmGSzwTE9zWH15CCwAwOBQersmzXDIauYpLLr0DXFqZPMsllHpkbbk7r3tHsH3Hep5G97WDz1AbclzWOnKIH/f60XAGmEckr0aD33uqkezBZsSWK4Ul2NFtkcKiM1y2rZzhhVH76tw2WYkUxSRhWFvgMVEohc8Pj8jKurMLkQgKdiESQcEuRCIo2IVIBAW7EInQ1dn4Qi6LLWvXBG2HT43RfvOkRteLB2hRW9QrfEa1p8RnYk8c5TPMQyPhmelGlc+atiysJADA2Cner6ePz4JXFngyxh2bdgTb33fXu2if6Spfkung0RPUdu+uXdT2wqnXgu3Wy5WQRpmP1VWbR6jt2Gv83NnYGz7fNhW4SjKXjbwvgzxp6Nz5KWrL9/CkrUY9PCYD/bym3bCFbTlTIowQyaNgFyIRFOxCJIKCXYhEULALkQgKdiESYSnLPz0K4NcBjLv7zZ22TwP4KICznZd90t2/u+jOslkMrx0M2taWp2m/ybHww/3e4vLUQKQG3fz8PLXlSL07AKjMhfdX5ptDpcmN81ypwYaNA9RWr3AZ53B5Ntje+5PnaZ/3XcMltB35ddS269pt1Lbnz18Otk+cnaN93nX7rdS2dStfFbxCpFkAmJ4Iy2hnx3gSVbXE35g6kckAoJ7nWVQbNnH/fe4MMdAuyJWGgu1mr9M+S7mz/yWA+wLtn3X32zo/iwa6EGJ1WTTY3f0HACa64IsQYgW5nO/sHzOz/Wb2qJlFssCFEFcClxrsnwdwPYDbAJwB8Bn2QjPbY2bPmtmztTp/zFMIsbJcUrC7+5i7N929BeALAO6MvHavu+92992FfFcfxRdCXMAlBbuZjV7w7wcBHFwed4QQK8VSpLcvAbgHwDozOwngUwDuMbPb0BYHjgH4/aXsrOlNzDVmgrb+wbAkBwBzc2E5aX6ayyClIs8YWruOS3bjZ3kG2NrhsK1e5RrJ2Qm+vVYkM2/mPD+2jIWXVgKAd/7z3wm2z71+ivaZez2coQYAM3OT1HbuBN/mJ37rA8H2v/vZftqnb/N11LZpeD21lXdy2fbU8UPB9olTRO4CUOnj76fl+blTn+Xv9asnuCQ2Uw6P8cahcMYeAAxtvybYns0foX0WDXZ3/3Cg+ZHF+gkhriz0BJ0QiaBgFyIRFOxCJIKCXYhEULALkQhdfcqlWmvgtaPhx+zrTb6ET29fWEbbsJkXDayU+dN6M/Nc8oo993P0ZLjfugF+zbxpA8+umgfPKKvXuYxTLPKih7fe/k+C7c0yzyhrHXiW2p78DpeMTp96ido+9Nu/HWyfneBZb994IZwpBwDv/b3bqC32ptWILHq18eWY8i+9QG0DRX7O5Yzbpoz7OF0KS2yNApdY65Pngu3e5Oe97uxCJIKCXYhEULALkQgKdiESQcEuRCIo2IVIBHOPVLVbZgr5vG9cFy5qk89zOaxQCq9fVTcuTzXnuW1kG5c0cjVe6PHXZsMZTw+cPU37PL5hK7V9b4Bn+lmTZ73VuEqJX7nnV4Pt/+6999I+jSOHqe2pfT+itjPj/LjffePNwfZz0zyLrpWNZCOW+FhVz/O13ga2bw2239Dg59tv9PLikHnwwffIem5eiawHeDK8ZmH5NM/MO/7az4Ltv/XKCby4UAkGjO7sQiSCgl2IRFCwC5EICnYhEkHBLkQidDURJptzDA6FZzOHBvks+Kmz4Yf+K7PhWXoAmJ7jtt3Dw9T2qetvpLab3rkl2J4Z5zPMR4/wWpxfjywlZJHEoIzzY/vR34QX57l9Ex9fe/04td184yZq+40HQhXL2swiPLM+Cn7Me//nn1Lbhu07qW0NqccGAKMeniG/pZfXKPSdfFmr
2i6eUJR5x03Uhv37qKn1xPeD7fnxE7TPzlo44aUUUdd0ZxciERTsQiSCgl2IRFCwC5EICnYhEkHBLkQiLGX5py0A/grAJgAtAHvd/XNmNgzgKwC2or0E1APuzjUoADkY1mfDkkd5YoH2K82F5YSBXn6terCPS01/WOG1wtacCct8AFA5FU5YyB09Rvv8WplLTafWFKntm5EkmSnjslwlF5a8nvvbf6B91hlPQLn7LE8Kyb3Ok2T6z58Nt5d5QsjvHeKnz8jLP6a2NSWe1NI/Ha55l3c+hlblSVS2iUuRtoPLtq1+XjcwOxdeviozxcfDe0bDhkx43IGl3dkbAD7h7rsA3AXgD8zsRgAPAXjS3XcAeLLzvxDiCmXRYHf3M+7+fOfvWQCHAGwGcD+AxzovewxAeCU/IcQVwUV9ZzezrQBuB/BTABvd/QzQviAA4J/3hBCrzpIflzWzfgDfAPBxd58x449svqXfHgB7AKCY13ygEKvFkqLPzPJoB/oX3f2bneYxMxvt2EcBBGev3H2vu+929935rIJdiNVi0eiz9i38EQCH3P3hC0yPA3iw8/eDAL69/O4JIZaLRWvQmdm7AfwDgANoS28A8Em0v7d/FcA1AI4D+E13D6/t1GHDUMn/zT3hDKX+4Ug9NrJ0zsbXeO2xjx7nckx223Zqy13L5RP7yU+C7X78EO8DLq+hxZfqOTscXhIIAM4PjFDbXCH89eq6Yj/tM7yGb896uCxnBf4t0HvD+8sOcj+y67kf6OVSqvfymoKtXFjqbTa4vNbK8K+ouWG+ZFc2w8cKeZ5l1yK786ee4tv73v8NNv/TY6/gufJCcIuLfmd39x8CYEcfrm4ohLji0JdoIRJBwS5EIijYhUgEBbsQiaBgFyIRulpwMp/P4Woir+TzXLZotsLy4L2H52mfwgCXSDJrNlIbDjxPTXb2VLj95l/hfW7jBQqxZTM1bR4KL5MFAJuLXMZBJZxl1zrHZUqQDDUAaJLChgCQ6eEymrXC0lZzjmc3+hG+nJQX+H3Jjfvo1bDNq2XeJyK91SKFUbMlLpdiLbc1rw6fq9ntvPBl9iO/EzZ87r/TPrqzC5EICnYhEkHBLkQiKNiFSAQFuxCJoGAXIhG6Kr3lMhkM9/YFbcUcLwLZOzYTbL9+LlIYcO51amue/A61LWzislzmhneEDTfsoH2wjks1mbGj1Nb6GZcAs1Oz1NasVoLth53LlINEngKA4XJ4ewBQrPHMwlYxfGpZnRd6RJ37YQWePdhCpHgk2V8mG8nYi2wPkWKfTT5UsEhRz1IpLKWebPLxmCe36cq587SP7uxCJIKCXYhEULALkQgKdiESQcEuRCJ0dTbeW456NZyoUavyWc6dL4eTOErOZzgbDb7MUAN8lrM0FV6KBwB6z00F2/3pZ2gfb3E/6pEliOqR2oAWuUZbNpzEsTXL1Y58hp8GWY8kmTifjc8g/N7E+ljEhhYfq0jlN8DD45EhyVXtPpGxt9j9kdvqkRn+h0nizZciu5ohLp5sRBKX+OaEEG8nFOxCJIKCXYhEULALkQgKdiESQcEuRCIsKr2Z2RYAfwVgE9rLP+1198+Z2acBfBTAGwXMPunu341tK5vLYmg4XIOuMc2lidFjYTmsthBOkAGA2LJW2YjqUqnwemw/yoflq/nNvF6c1bj0NjrLMye2z3Gb0QV6ADTC45iPSDIxmkS6avvBcWaNdIoIb4vsK0Zsq2GakZ1ZJBGmEPHkryNLZX1mMLx81c538GXKthTDTp5/+iXaZyk6ewPAJ9z9eTMbAPCcmT3RsX3W3f/bErYhhFhllrLW2xkAZzp/z5rZIQC8LKoQ4orkor6zm9lWALejvYIrAHzMzPab2aNmxj/LCiFWnSUHu5n1A/gGgI+7+wyAzwO4HsBtaN/5P0P67TGzZ83s2dkFXmxCCLGyLCnYzSyPdqB/0d2/CQDuPubuTW8/7PwFAHeG+rr7Xnff7e67B3ojixsIIVaURYPd
zAzAIwAOufvDF7SPXvCyDwI4uPzuCSGWi6XMxt8N4HcBHDCzfZ22TwL4sJndhrbycQzA7y+2oUwmg1IpLDPkfswlg6GpcLZZNSJ1xOSpmnHbH/XyWmf7tmwItl+zayfts37TVmo79+qL1Lb9hzyT7j9GasZlyXG3Itf1mHQVGSo07eLHPxPVyWLb48S26eQAoscc2VuuxaW86ch4fCXPQ23baLju4QP/+t/SPn194fP0wKsPB9uBpc3G/xDhsY5q6kKIKws9QSdEIijYhUgEBbsQiaBgFyIRFOxCJELXC07WFsKy0Ttf4xlsuWL4YRwrh4tXtuHZSd8r9FDb94f5U7+3rOsPthcwR/uM9PN9VUbC2wOA72xZT213Hg0X4ASA95BCipEFjVCIZAjGcsaykX6XIvTFfIwk310Ssc3FClieuHaY2o6XeYbjqchA3kKWCHvl2Mu0z8jawWB7tc6fUtWdXYhEULALkQgKdiESQcEuRCIo2IVIBAW7EInQVekNmRyyvWHp4pl38cwxeyUsM5R+/grtM9jkAsq+DBd5cnxJNJSIBHhNXx/tUzv3Gt+ec8lucM0aavv70nlqu3cufGy5yLpysQywSz9Bwlu95H1dovbmi5SjDGGRPj0VLveedn7vzBR5NuUIybRszR+lfWqVsKTrdV6oVHd2IRJBwS5EIijYhUgEBbsQiaBgFyIRFOxCJEJXpTczoFAIp/+MXR3O/AGAr50Oy0bPb+CSV2OaSxA/b3IZylr8+lcYCMuGmzaECwa2t7dAbb+Y56W1a9UytZ1z/rZNjoYlu4mdN9E++SYvYJmLSF6ZZmQ9PWaLVbCM5di1ItJh5uJXgmuRNfEAIBO5B/bO8vezdvIwtVkfl4IbpIjltqFNtE+rGc6wy2Ui8h+1CCHeVijYhUgEBbsQiaBgFyIRFOxCJMKis/FmVgLwAwDFzuu/7u6fMrPrAHwZwDCA5wH8rrtHl2nNZrLo6wvPaBdLfEb470vha9JPIrPIcxk+s5uLVCAbmOG18PI94fp0ozfdQ/vMnz9HbeMnnqK2uSqfLX6uwZWGv6iEZ31PnDtN+2Qjk9mFDJ9FLhi3tcgMeTbL+1h0pj6yNFREMWBLOVmW3+eiS4cNcgXllRzv5xGhYbYZDsNaL69RWCoSW477t5Q7exXAve5+K9rLM99nZncB+BMAn3X3HQAmAXxkCdsSQqwSiwa7t3kjFzPf+XEA9wL4eqf9MQAfWBEPhRDLwlLXZ892VnAdB/AEgNcATLn7G5+jTwLYvDIuCiGWgyUFu7s33f02AFcDuBPArtDLQn3NbI+ZPWtmz07P8afChBAry0XNxrv7FIC/A3AXgCEze2Nm4WoAwRkgd9/r7rvdffeayIIJQoiVZdFgN7P1ZjbU+bsHwL8AcAjAUwDeWC3+QQDfXiknhRCXz1ISYUYBPGZmWbQvDl919/9jZi8B+LKZ/RcAPwPwyGIbyhcKuOrq8Fd7z3PJ4O5yuFbbDaMbaJ/5CpenWk2ugxwb4/XdDh48EGzfecMdtE9/H5dPXh+forbpiQlqq/ZwiecvMmH1M3OC1zObrXDFtF6PJYxEpCbWHikJZ8aNsUpyMcGO3c1iuTOFiIQ21M8TtsZJcgoA1Ce5pDs+MRvuY3xf2669PdheKDxO+ywa7O6+H8Avbdndj6D9/V0I8Y8APUEnRCIo2IVIBAW7EImgYBciERTsQiSCeUwLWe6dmZ0F8IvOv+sA8JSw7iE/3oz8eDP/2Py41t3XhwxdDfY37djsWXffvSo7lx/yI0E/9DFeiERQsAuRCKsZ7HtXcd8XIj/ejPx4M28bP1btO7sQorvoY7wQibAqwW5m95nZK2Z22MweWg0fOn4cM7MDZrbPzJ7t4n4fNbNxMzt4QduwmT1hZj/v/A5Xt1x5Pz5tZqc6Y7LPzN7fBT+2mNlTZnbIzF40s//Qae/qmET86OqYmFnJzJ42sxc6fvxRp/06M/tpZzy+YmY8VTSEu3f1
B0AW7bJW2wAUALwA4MZu+9Hx5RiAdauw3/cAuAPAwQva/iuAhzp/PwTgT1bJj08D+MMuj8cogDs6fw8AeBXAjd0ek4gfXR0TtLN2+zt/5wH8FO2CMV8F8KFO+58B+PcXs93VuLPfCeCwux/xdunpLwO4fxX8WDXc/QcA3pqwfj/ahTuBLhXwJH50HXc/4+7Pd/6eRbs4ymZ0eUwifnQVb7PsRV5XI9g3Azhxwf+rWazSAXzfzJ4zsz2r5MMbbHT3M0D7pAPAK3OsPB8zs/2dj/kr/nXiQsxsK9r1E36KVRyTt/gBdHlMVqLI62oEe6gMyGpJAne7+x0A/hWAPzCz96ySH1cSnwdwPdprBJwB8Jlu7djM+gF8A8DH3Z2Xdum+H10fE7+MIq+M1Qj2kwC2XPA/LVa50rj76c7vcQDfwupW3hkzs1EA6PweXw0n3H2sc6K1AHwBXRoTM8ujHWBfdPdvdpq7PiYhP1ZrTDr7vugir4zVCPZnAOzozCwWAHwIAC+ctUKYWZ9Zu8iXmfUBeB+Ag/FeK8rjaBfuBFaxgOcbwdXhg+jCmFh73adHABxy94cvMHV1TJgf3R6TFSvy2q0ZxrfMNr4f7ZnO1wD8p1XyYRvaSsALAF7sph8AvoT2x8E62p90PgJgBMCTAH7e+T28Sn78NYADAPajHWyjXfDj3Wh/JN0PYF/n5/3dHpOIH10dEwC3oF3EdT/aF5b/fME5+zSAwwC+BqB4MdvVE3RCJIKeoBMiERTsQiSCgl2IRFCwC5EICnYhEkHBLkQiKNiFSAQFuxCJ8P8A6U1LCVG1pHgAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "fig, ax = plt.subplots()\n",
    "ax.imshow(img)  # matplotlib can draw the PIL image directly\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([3, 32, 32])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Convert the PIL image into a float tensor of shape (channels, height, width).\n",
    "from torchvision import transforms\n",
    "\n",
    "to_tensor = transforms.ToTensor()\n",
    "img_t = to_tensor(img)\n",
    "img_t.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-open the whole training split, this time converting every sample to a tensor.\n",
    "tensor_cifar10 = datasets.CIFAR10(\n",
    "    data_path, train=True, download=False, transform=transforms.ToTensor()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Tensor"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same sample as before, but the dataset now yields a tensor instead of a PIL image.\n",
    "img_t, _ = tensor_cifar10[99]\n",
    "type(img_t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7f662cf57290>"
      ]
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD5CAYAAADhukOtAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAfCklEQVR4nO2de4xd13Xev3Xf8+Rwhq8RRYmiSIuUZL1Kq0rlGrLSOqobRDbaKHaaQAgMMyhioEadPwQXqB2gfyRFLddNCwd0pEQJXL9tWKgNx6qixDH80MsUSYmSTJE0n5ohOe+Z+76rf9yrgpL3t2dIztxhtL8fMJiZve4+Z519zzrn3v2dtba5O4QQb38yq+2AEKI7KNiFSAQFuxCJoGAXIhEU7EIkgoJdiETIXU5nM7sPwOcAZAH8ubv/cez1+XzOS6V80NZqNWk/b7WYA7RPJnoZ4/1iNvewHxE3EJM2zbKX4AVgkR1mc+HxzWbD7QBQXpiL7I2MPYCeUg+19fX2B9sXFuZpn3q9TG2ZyDHns/w0zuSKwfbe/nA7ADQj52K5xv3P5/hJl89F3utM+BzJZfn2FhbCfSYny5ifrwUH65KD3dpn6v8C8C8BnATwjJk97u4vsT6lUh537N4etM3NTNB9NWrVYHs2zwejtzcStK3IYWe4rVYN+5GPbK5Zr1FbPjdAbRYJ93yBn6hrRzYG24cGN9E+Bw78kNrg3P9dN9xMbXfd+s+C7c+98DTt8/rpg9TWW+QXq6sG1lNb37rrgu233L2N9pmpTlHboaPc/00b+fu5cYTbir3hi8tQ5IK0f18j2P4//vRHtM/lfIy/E8Bhdz/i7jUAXwZw/2VsTwixglxOsG8GcOKC/0922oQQVyCX85099Dnzl75ImNkeAHsAoBj5KCaEWFku585+EsCWC/6/GsDpt77I3fe6+253353P80kKIcTKcjnB/gyAHWZ2nZkVAHwIwOPL45YQYrm55I/x7t4ws48B+Bu0pbdH3f3FaCdzmJEZ7chNP1MoBdtzxci1KqJdmfOdVebD/gFAi8hQsdlxy0Wkt1x4RrVNgVomZ6ap7dzkZLC9XN7H/YjIa3094bEHgLHJ89T2xI//NtjeMi5rzdQq1NYT8WOmwvsNDYYlwJ5iWBUCgC2jfOZ8avqXPrz+f4ZHuB8Dg/ycW6iG5by5BX4OlHrDX4kzGX7iX5bO7u7fBfDdy9mGEKI76Ak6IRJBwS5EIijYhUgEBbsQiaBgFyIRLms2/mJxB+rNsBTVM9BH+1VILkaryaWOZoM/rVetcHmtvz8s1QCA12fC+2JZeQBaxq+nxVxEH8zwTLR8ictQtdlw5lixxGUcGJcA3XgizOnx49SWJ9lB1QUuvRUitU97CtyPaoZvs3YsnFyzUDtF+5SKa6ntqi1XU1tlluaAYWyW+5gthM+DWecZduMT4XO43uDvpe7sQiSCgl2IRFCwC5EICnYhEkHBLkQidHU2PmNAkSSvTM8s0H7m4ZnkWJJGLHFivnzxdeYAoFwLTxf39kdmupt8drS8wGuu1Svcj1ypTm1m4X65SA00j13ziXoCAD15rnjU6+FTK9PkfrScqysLkQSlnh6euFJeCCcGjZ3l+5pbOEFtg8P3Ulupl5f+mqmMUVulHB7jJrgCcW46PB6NJj9vdGcXIhEU7EIkgoJdiERQsAuRCAp2IRJBwS5EInRVemu2WpgniRp1roRgaE1YRquUuVzXjCQETE9zSWNmJpzsAgAjZFWPfq7yYXomIr3NcVkrX+BvzcJ8JHGFSIfu/LpeLfMkjVY9UkMvy2WeYj68TSvx7TW4G23dltCb5bZyeCUknJ3kSSbFYqTe3RSvuzdJ5DAAGD/HbYOD4fcmcgqjPB8+Lm9GlkTjmxNCvJ1QsAuRCAp2IRJBwS5EIijYhUgEBbsQiXBZ0puZHQMwC6AJoOHuu2Ovz5ihUApnPZVK
PINqjix3VI9oNbUaP7Rqldd3Gx7hfgwOhtvHTvPt1Vo8Q61IxgIAIgllyEXGqrIQll4qFe5HqRgZq0jmlbe4NsSS2/KRmnzNekQ2ikiR5RLvNzUf9r/RjNSEW8vH98zYSWqrtXgWYyWiLVfKYamvGclgK1fD/sf6LIfO/l53P7cM2xFCrCD6GC9EIlxusDuA75vZc2a2ZzkcEkKsDJf7Mf5udz9tZhsAPGFmL7v7Dy58QecisAcAisXIusxCiBXlsu7s7n6683scwLcA3Bl4zV533+3uu/OxRdiFECvKJQe7mfWZ2cAbfwN4H4Dw8htCiFXncj7GbwTwLTN7Yzv/292/F+vQagELc2FpIJPlskWOeJnN80KPHpEgtu8aoraBPj4kM+fC8lVzbSTrKpJRlokUgawRaQUAhoZ5v7XrwrLR3Az3sVrmYzW8kS/LVTQuUc3MhSWvOmLLIPHtlSMy60KLj0eDLBHWLHNJcdb4vqo1LjeuHR6mtkjdTix4WLot5vj53WzNBtvdue+XHOzufgTArZfaXwjRXSS9CZEICnYhEkHBLkQiKNiFSAQFuxCJ0N213jLAYG/4+pKNZDXNz4ZlknwuUrCxxGWLFilCCAB149lhXghLVCMkGw4ATp/g+2IyJAA0nfuRK/GxWjsYlq+akfXtCpHt9cbGscX9b5Fss6F1vJhjmdeAxOw0zxqbOBfOigSA/t6w/znSDgDNFj+v6lVum54Oy2FAPNOyRNYlzA/x9+yqzevDfQq8IKbu7EIkgoJdiERQsAuRCAp2IRJBwS5EInR1Nt4B1FrhGcbZMT5buXY4PN3davLln+oWmWHu5UvxzEVmW5u18AxzqcBndgcGuG1NH0/gmJjiM93TE5FZ/GrYxxz4cfVHfKws8LGqkX0BwOBQMdheYFlNAIoRVeP8GJ+Z7unn4zhfDZ8jxYgCUY2dAwtcJelt8nHMFWPJUuEx9kjSUJlIF/VIoo7u7EIkgoJdiERQsAuRCAp2IRJBwS5EIijYhUiErkpvrWYLs3NhyaDZ5DLOPJEmZqa4LFTMc4kkm+W1zrKZyBJEpL1Wi9T9ynNbT4FLPOU6vw67x+TBsCzXihxzZYInmRSy/BTJZ3u4Hx6WvGJjXyvzY85YZImnaX7urB0JS4DlKj93qjU+viNDsUQeLnstVLmtRU6R6Unux+jGtcF256qs7uxCpIKCXYhEULALkQgKdiESQcEuRCIo2IVIhEWlNzN7FMCvAxh395s7bcMAvgJgK4BjAB5w98nFtpXJZDBQCss1Y7N8+aeF8kyw3Z1nO3kzslzQLL/GXbern9oqpNTZ1ByXcTxSp63a4LbSGn5sff0R+Wo6vM2p89zHVpZLPC3jkpGD23qHwmPcynCZbM36Xmq7rsht01NcOmzUiY+R9ZgG1vDzYzBSFw4tHk7HT/MMzeHh8BJbg5FsxFotHC8e0d6Wcmf/SwD3vaXtIQBPuvsOAE92/hdCXMEsGuyd9dYn3tJ8P4DHOn8/BuADy+yXEGKZudTv7Bvd/QwAdH5vWD6XhBArwYo/LmtmewDsAYBCgX8PFUKsLJd6Zx8zs1EA6PweZy90973uvtvdd+fzCnYhVotLDfbHATzY+ftBAN9eHneEECvFUqS3LwG4B8A6MzsJ4FMA/hjAV83sIwCOA/jNpewskzH0kqVuMpG7foYsx1PiCUhYt5Eb123kh91ocolqZi4s59W4qoJGnUuAw1fxrLGhYb7NapVvc5ZkCDYikoxX+TV/03Yu/9Qr3I+shW3ZHO+DDJfycgVu6+vn7+fZ8bDU11eMZPNFikNOz3E/Bvr4WF3VxyXdSSLdDkbk11IpbMtEsjYXDXZ3/zAx/epifYUQVw56gk6IRFCwC5EICnYhEkHBLkQiKNiFSISuFpysVut49cjJsNF4JlepJ3xNWj/KpauRkVj2D894atT4kPT1h2WNniL3/fgvuNRkkWvt3CyXeKbOc1ujTo4tkr1W7OcZZY3I2mHZ
XORe0QxLn1OTXNrM57iGmY+cqtaMZD8S6bNl/ByIqFdoRQpHzhf5eGzdyM+RzEw4a6/ViBUWDR+z+8UXTBVCvM1QsAuRCAp2IRJBwS5EIijYhUgEBbsQidBV6c3d0GqFJYh6ja/NNrI+vF7Xtp3hQn0AMHmGSzwTE9zWH15CCwAwOBQersmzXDIauYpLLr0DXFqZPMsllHpkbbk7r3tHsH3Hep5G97WDz1AbclzWOnKIH/f60XAGmEckr0aD33uqkezBZsSWK4Ul2NFtkcKiM1y2rZzhhVH76tw2WYkUxSRhWFvgMVEohc8Pj8jKurMLkQgKdiESQcEuRCIo2IVIBAW7EInQ1dn4Qi6LLWvXBG2HT43RfvOkRteLB2hRW9QrfEa1p8RnYk8c5TPMQyPhmelGlc+atiysJADA2Cner6ePz4JXFngyxh2bdgTb33fXu2if6Spfkung0RPUdu+uXdT2wqnXgu3Wy5WQRpmP1VWbR6jt2Gv83NnYGz7fNhW4SjKXjbwvgzxp6Nz5KWrL9/CkrUY9PCYD/bym3bCFbTlTIowQyaNgFyIRFOxCJIKCXYhEULALkQgKdiESYSnLPz0K4NcBjLv7zZ22TwP4KICznZd90t2/u+jOslkMrx0M2taWp2m/ybHww/3e4vLUQKQG3fz8PLXlSL07AKjMhfdX5ptDpcmN81ypwYaNA9RWr3AZ53B5Ntje+5PnaZ/3XcMltB35ddS269pt1Lbnz18Otk+cnaN93nX7rdS2dStfFbxCpFkAmJ4Iy2hnx3gSVbXE35g6kckAoJ7nWVQbNnH/fe4MMdAuyJWGgu1mr9M+S7mz/yWA+wLtn3X32zo/iwa6EGJ1WTTY3f0HACa64IsQYgW5nO/sHzOz/Wb2qJlFssCFEFcClxrsnwdwPYDbAJwB8Bn2QjPbY2bPmtmztTp/zFMIsbJcUrC7+5i7N929BeALAO6MvHavu+92992FfFcfxRdCXMAlBbuZjV7w7wcBHFwed4QQK8VSpLcvAbgHwDozOwngUwDuMbPb0BYHjgH4/aXsrOlNzDVmgrb+wbAkBwBzc2E5aX6ayyClIs8YWruOS3bjZ3kG2NrhsK1e5RrJ2Qm+vVYkM2/mPD+2jIWXVgKAd/7z3wm2z71+ivaZez2coQYAM3OT1HbuBN/mJ37rA8H2v/vZftqnb/N11LZpeD21lXdy2fbU8UPB9olTRO4CUOnj76fl+blTn+Xv9asnuCQ2Uw6P8cahcMYeAAxtvybYns0foX0WDXZ3/3Cg+ZHF+gkhriz0BJ0QiaBgFyIRFOxCJIKCXYhEULALkQhdfcqlWmvgtaPhx+zrTb6ET29fWEbbsJkXDayU+dN6M/Nc8oo993P0ZLjfugF+zbxpA8+umgfPKKvXuYxTLPKih7fe/k+C7c0yzyhrHXiW2p78DpeMTp96ido+9Nu/HWyfneBZb994IZwpBwDv/b3bqC32ptWILHq18eWY8i+9QG0DRX7O5Yzbpoz7OF0KS2yNApdY65Pngu3e5Oe97uxCJIKCXYhEULALkQgKdiESQcEuRCIo2IVIBHOPVLVbZgr5vG9cFy5qk89zOaxQCq9fVTcuTzXnuW1kG5c0cjVe6PHXZsMZTw+cPU37PL5hK7V9b4Bn+lmTZ73VuEqJX7nnV4Pt/+6999I+jSOHqe2pfT+itjPj/LjffePNwfZz0zyLrpWNZCOW+FhVz/O13ga2bw2239Dg59tv9PLikHnwwffIem5eiawHeDK8ZmH5NM/MO/7az4Ltv/XKCby4UAkGjO7sQiSCgl2IRFCwC5EICnYhEkHBLkQidDURJptzDA6FZzOHBvks+Kmz4Yf+K7PhWXoAmJ7jtt3Dw9T2qetvpLab3rkl2J4Z5zPMR4/wWpxfjywlZJHEoIzzY/vR34QX57l9Ex9fe/04td184yZq+40HQhXL2swiPLM+Cn7Me//nn1Lbhu07qW0NqccGAKMeniG/pZfXKPSdfFmr
2i6eUJR5x03Uhv37qKn1xPeD7fnxE7TPzlo44aUUUdd0ZxciERTsQiSCgl2IRFCwC5EICnYhEkHBLkQiLGX5py0A/grAJgAtAHvd/XNmNgzgKwC2or0E1APuzjUoADkY1mfDkkd5YoH2K82F5YSBXn6terCPS01/WOG1wtacCct8AFA5FU5YyB09Rvv8WplLTafWFKntm5EkmSnjslwlF5a8nvvbf6B91hlPQLn7LE8Kyb3Ok2T6z58Nt5d5QsjvHeKnz8jLP6a2NSWe1NI/Ha55l3c+hlblSVS2iUuRtoPLtq1+XjcwOxdeviozxcfDe0bDhkx43IGl3dkbAD7h7rsA3AXgD8zsRgAPAXjS3XcAeLLzvxDiCmXRYHf3M+7+fOfvWQCHAGwGcD+AxzovewxAeCU/IcQVwUV9ZzezrQBuB/BTABvd/QzQviAA4J/3hBCrzpIflzWzfgDfAPBxd58x449svqXfHgB7AKCY13ygEKvFkqLPzPJoB/oX3f2bneYxMxvt2EcBBGev3H2vu+929935rIJdiNVi0eiz9i38EQCH3P3hC0yPA3iw8/eDAL69/O4JIZaLRWvQmdm7AfwDgANoS28A8Em0v7d/FcA1AI4D+E13D6/t1GHDUMn/zT3hDKX+4Ug9NrJ0zsbXeO2xjx7nckx223Zqy13L5RP7yU+C7X78EO8DLq+hxZfqOTscXhIIAM4PjFDbXCH89eq6Yj/tM7yGb896uCxnBf4t0HvD+8sOcj+y67kf6OVSqvfymoKtXFjqbTa4vNbK8K+ouWG+ZFc2w8cKeZ5l1yK786ee4tv73v8NNv/TY6/gufJCcIuLfmd39x8CYEcfrm4ohLji0JdoIRJBwS5EIijYhUgEBbsQiaBgFyIRulpwMp/P4Woir+TzXLZotsLy4L2H52mfwgCXSDJrNlIbDjxPTXb2VLj95l/hfW7jBQqxZTM1bR4KL5MFAJuLXMZBJZxl1zrHZUqQDDUAaJLChgCQ6eEymrXC0lZzjmc3+hG+nJQX+H3Jjfvo1bDNq2XeJyK91SKFUbMlLpdiLbc1rw6fq9ntvPBl9iO/EzZ87r/TPrqzC5EICnYhEkHBLkQiKNiFSAQFuxCJoGAXIhG6Kr3lMhkM9/YFbcUcLwLZOzYTbL9+LlIYcO51amue/A61LWzislzmhneEDTfsoH2wjks1mbGj1Nb6GZcAs1Oz1NasVoLth53LlINEngKA4XJ4ewBQrPHMwlYxfGpZnRd6RJ37YQWePdhCpHgk2V8mG8nYi2wPkWKfTT5UsEhRz1IpLKWebPLxmCe36cq587SP7uxCJIKCXYhEULALkQgKdiESQcEuRCJ0dTbeW456NZyoUavyWc6dL4eTOErOZzgbDb7MUAN8lrM0FV6KBwB6z00F2/3pZ2gfb3E/6pEliOqR2oAWuUZbNpzEsTXL1Y58hp8GWY8kmTifjc8g/N7E+ljEhhYfq0jlN8DD45EhyVXtPpGxt9j9kdvqkRn+h0nizZciu5ohLp5sRBKX+OaEEG8nFOxCJIKCXYhEULALkQgKdiESQcEuRCIsKr2Z2RYAfwVgE9rLP+1198+Z2acBfBTAGwXMPunu341tK5vLYmg4XIOuMc2lidFjYTmsthBOkAGA2LJW2YjqUqnwemw/yoflq/nNvF6c1bj0NjrLMye2z3Gb0QV6ADTC45iPSDIxmkS6avvBcWaNdIoIb4vsK0Zsq2GakZ1ZJBGmEPHkryNLZX1mMLx81c538GXKthTDTp5/+iXaZyk6ewPAJ9z9eTMbAPCcmT3RsX3W3f/bErYhhFhllrLW2xkAZzp/z5rZIQC8LKoQ4orkor6zm9lWALejvYIrAHzMzPab2aNmxj/LCiFWnSUHu5n1A/gGgI+7+wyAzwO4HsBtaN/5P0P67TGzZ83s2dkFXmxCCLGyLCnYzSyPdqB/0d2/CQDuPubuTW8/7PwFAHeG+rr7Xnff7e67B3ojixsIIVaURYPd
zAzAIwAOufvDF7SPXvCyDwI4uPzuCSGWi6XMxt8N4HcBHDCzfZ22TwL4sJndhrbycQzA7y+2oUwmg1IpLDPkfswlg6GpcLZZNSJ1xOSpmnHbH/XyWmf7tmwItl+zayfts37TVmo79+qL1Lb9hzyT7j9GasZlyXG3Itf1mHQVGSo07eLHPxPVyWLb48S26eQAoscc2VuuxaW86ch4fCXPQ23baLju4QP/+t/SPn194fP0wKsPB9uBpc3G/xDhsY5q6kKIKws9QSdEIijYhUgEBbsQiaBgFyIRFOxCJELXC07WFsKy0Ttf4xlsuWL4YRwrh4tXtuHZSd8r9FDb94f5U7+3rOsPthcwR/uM9PN9VUbC2wOA72xZT213Hg0X4ASA95BCipEFjVCIZAjGcsaykX6XIvTFfIwk310Ssc3FClieuHaY2o6XeYbjqchA3kKWCHvl2Mu0z8jawWB7tc6fUtWdXYhEULALkQgKdiESQcEuRCIo2IVIBAW7EInQVekNmRyyvWHp4pl38cwxeyUsM5R+/grtM9jkAsq+DBd5cnxJNJSIBHhNXx/tUzv3Gt+ec8lucM0aavv70nlqu3cufGy5yLpysQywSz9Bwlu95H1dovbmi5SjDGGRPj0VLveedn7vzBR5NuUIybRszR+lfWqVsKTrdV6oVHd2IRJBwS5EIijYhUgEBbsQiaBgFyIRFOxCJEJXpTczoFAIp/+MXR3O/AGAr50Oy0bPb+CSV2OaSxA/b3IZylr8+lcYCMuGmzaECwa2t7dAbb+Y56W1a9UytZ1z/rZNjoYlu4mdN9E++SYvYJmLSF6ZZmQ9PWaLVbCM5di1ItJh5uJXgmuRNfEAIBO5B/bO8vezdvIwtVkfl4IbpIjltqFNtE+rGc6wy2Ui8h+1CCHeVijYhUgEBbsQiaBgFyIRFOxCJMKis/FmVgLwAwDFzuu/7u6fMrPrAHwZwDCA5wH8rrtHl2nNZrLo6wvPaBdLfEb470vha9JPIrPIcxk+s5uLVCAbmOG18PI94fp0ozfdQ/vMnz9HbeMnnqK2uSqfLX6uwZWGv6iEZ31PnDtN+2Qjk9mFDJ9FLhi3tcgMeTbL+1h0pj6yNFREMWBLOVmW3+eiS4cNcgXllRzv5xGhYbYZDsNaL69RWCoSW477t5Q7exXAve5+K9rLM99nZncB+BMAn3X3HQAmAXxkCdsSQqwSiwa7t3kjFzPf+XEA9wL4eqf9MQAfWBEPhRDLwlLXZ892VnAdB/AEgNcATLn7G5+jTwLYvDIuCiGWgyUFu7s33f02AFcDuBPArtDLQn3NbI+ZPWtmz07P8afChBAry0XNxrv7FIC/A3AXgCEze2Nm4WoAwRkgd9/r7rvdffeayIIJQoiVZdFgN7P1ZjbU+bsHwL8AcAjAUwDeWC3+QQDfXiknhRCXz1ISYUYBPGZmWbQvDl919/9jZi8B+LKZ/RcAPwPwyGIbyhcKuOrq8Fd7z3PJ4O5yuFbbDaMbaJ/5CpenWk2ugxwb4/XdDh48EGzfecMdtE9/H5dPXh+forbpiQlqq/ZwiecvMmH1M3OC1zObrXDFtF6PJYxEpCbWHikJZ8aNsUpyMcGO3c1iuTOFiIQ21M8TtsZJcgoA1Ce5pDs+MRvuY3xf2669PdheKDxO+ywa7O6+H8Avbdndj6D9/V0I8Y8APUEnRCIo2IVIBAW7EImgYBciERTsQiSCeUwLWe6dmZ0F8IvOv+sA8JSw7iE/3oz8eDP/2Py41t3XhwxdDfY37djsWXffvSo7lx/yI0E/9DFeiERQsAuRCKsZ7HtXcd8XIj/ejPx4M28bP1btO7sQorvoY7wQibAqwW5m95nZK2Z22MweWg0fOn4cM7MDZrbPzJ7t4n4fNbNxMzt4QduwmT1hZj/v/A5Xt1x5Pz5tZqc6Y7LPzN7fBT+2mNlTZnbIzF40s//Qae/qmET86OqYmFnJzJ42sxc6fvxRp/06M/tpZzy+YmY8VTSEu3f1
B0AW7bJW2wAUALwA4MZu+9Hx5RiAdauw3/cAuAPAwQva/iuAhzp/PwTgT1bJj08D+MMuj8cogDs6fw8AeBXAjd0ek4gfXR0TtLN2+zt/5wH8FO2CMV8F8KFO+58B+PcXs93VuLPfCeCwux/xdunpLwO4fxX8WDXc/QcA3pqwfj/ahTuBLhXwJH50HXc/4+7Pd/6eRbs4ymZ0eUwifnQVb7PsRV5XI9g3Azhxwf+rWazSAXzfzJ4zsz2r5MMbbHT3M0D7pAPAK3OsPB8zs/2dj/kr/nXiQsxsK9r1E36KVRyTt/gBdHlMVqLI62oEe6gMyGpJAne7+x0A/hWAPzCz96ySH1cSnwdwPdprBJwB8Jlu7djM+gF8A8DH3Z2Xdum+H10fE7+MIq+M1Qj2kwC2XPA/LVa50rj76c7vcQDfwupW3hkzs1EA6PweXw0n3H2sc6K1AHwBXRoTM8ujHWBfdPdvdpq7PiYhP1ZrTDr7vugir4zVCPZnAOzozCwWAHwIAC+ctUKYWZ9Zu8iXmfUBeB+Ag/FeK8rjaBfuBFaxgOcbwdXhg+jCmFh73adHABxy94cvMHV1TJgf3R6TFSvy2q0ZxrfMNr4f7ZnO1wD8p1XyYRvaSsALAF7sph8AvoT2x8E62p90PgJgBMCTAH7e+T28Sn78NYADAPajHWyjXfDj3Wh/JN0PYF/n5/3dHpOIH10dEwC3oF3EdT/aF5b/fME5+zSAwwC+BqB4MdvVE3RCJIKeoBMiERTsQiSCgl2IRFCwC5EICnYhEkHBLkQiKNiFSAQFuxCJ8P8A6U1LCVG1pHgAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# The tensor is (C, H, W) but imshow expects (H, W, C): move the channel axis last.\n",
    "plt.imshow(img_t.permute(1, 2, 0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([3, 32, 32, 50000])"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Data normalization, step 1: stack all training images along a new last axis,\n",
    "# giving (3, 32, 32, N), so per-channel statistics can be computed in one pass.\n",
    "imgs = torch.stack([sample for sample, _ in tensor_cifar10], dim=3)\n",
    "imgs.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.4915, 0.4823, 0.4468])"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Per-channel mean: flatten each channel into one row, then average along that row.\n",
    "imgs.view(3, -1).mean(dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.2470, 0.2435, 0.2616])"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Per-channel standard deviation, computed over the same flattened rows as the mean.\n",
    "torch.std(imgs.view(3, -1), dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Normalize(mean=(0.4915, 0.4823, 0.4468), std=(0.247, 0.2435, 0.2616))"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Normalize with the per-channel mean/std printed above (rounded to 4 decimals).\n",
    "transforms.Normalize(mean=(0.4915, 0.4823, 0.4468), std=(0.2470, 0.2435, 0.2616))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normalize inside the dataset pipeline: ToTensor, then per-channel standardization\n",
    "# using the training-set mean/std computed above.\n",
    "transformed_cifar10 = datasets.CIFAR10(\n",
    "    data_path, train=True, download=False,\n",
    "    transform=transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((0.4915, 0.4823, 0.4468),\n",
    "                             (0.2470, 0.2435, 0.2616)),\n",
    "    ]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The validation split reuses the training-set statistics, so nothing is estimated\n",
    "# from validation data.\n",
    "transformed_cifar10_val = datasets.CIFAR10(\n",
    "    data_path, train=False, download=False,\n",
    "    transform=transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((0.4915, 0.4823, 0.4468),\n",
    "                             (0.2470, 0.2435, 0.2616)),\n",
    "    ]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7f662ceca510>"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD5CAYAAADhukOtAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAQSklEQVR4nO3df6xX9X3H8ee7/BAsrPyGO0BRQlPcUCB3xgbttNscmm7onKuuMZg4r1tqqolNRmg2mOmS2fgjpG3srkJKOyoyRWHWrSXEhrpM61URUCxQREVv+VElok70wnt/fA/ZhZ7P53vv93u+51z4vB4Jud/7eX/POW+PvPh+v+d8z+eYuyMip79PVd2AiJRDYRdJhMIukgiFXSQRCrtIIhR2kUQMbmZhM5sPLAMGAQ+6+7/Ueb7O8yVi3PAhueMH//eTkjvJd+5ZFqx98HH4r+m+X4fXOXxUuDY2Uhs6LH985JnhZXa8mj/+8RHo6fHc/zhr9Dy7mQ0CdgB/AuwFngOud/dXIsso7Im4adbk3PHlW98quZN8j3zvjGDtmTePBGt3/3N4nRf8Rbh2w5+Ha1Nm5o9fNie8zOXz8sd3vAwffpAf9mbexl8I7HL33e7+MbAaWNDE+kSkhZoJ+2TgzV6/783GRGQAauYze95bhd96m25mHUBHE9sRkQI0E/a9wNRev08B3j75Se7eCXSCPrOLVKmZt/HPATPM7BwzGwpcB6wvpi0RKVrDr+zu3mNmtwI/oXbqbYW7v1xYZ3JKGyhH3YcGxmdM+VZwmWs65gZrT226JFi7InLEvf3z4dqrb+aPv7g9vMy0wBH8PbvDyzR1nt3dnwSebGYdIlIOfYNOJBEKu0giFHaRRCjsIolQ2EUS0fCFMA1tTF+qkVPcLX8drr0fubItcGEbACPb8scP94SXWf7dQOEQ+CfFXwgjIqcQhV0kEQq7SCIUdpFEKOwiiWjqu/Eiqdm8NVwLXZwC8Mxr4dprO/PHP4w1cihWzKdXdpFEKOwiiVDYRRKhsIskQmEXSYTCLpIIXQgjcppx14UwIklT2EUSobCLJEJhF0mEwi6SCIVdJBFNXfVmZnuAw8BRoMfd24toSkSKV8Qlrpe5+8EC1iMiLaS38SKJaDbsDvzUzJ43s44iGhKR1mj2bfw8d3/bzCYAG8zsVXff1PsJ2T8C+odApGKFfTfezJYC77v73ZHn6LvxIi1W+HfjzezTZjby+GPgcmBbo+sTkdZq5m38ROAxMzu+nh+5+38V0pWIFE6XuIqcZnSJq0jiFHaRRCjsIolQ2EUSobCLJEJhF0mEwi6SCIVdJBEKu0giFHaRRCjsIolQ2EUSobCLJEJhF0mEwi6SCIVdJBEKu0giFHaRRCjsIolQ2EUSobCLJEJhF0mEwi6SCIVdJBEKu0gi6obdzFaY2X4z29ZrbIyZbTCzndnP0a1tU0Sa1ZdX9u8D808aWwRsdPcZwMbsdxEZwOqGPbvf+jsnDS8AVmaPVwJXFdyXiBSs0c/sE929GyD7OaG4lkSkFZq5ZXOfmFkH0NHq7YhIXKOv7PvMrA0g+7k/9ER373T3dndvb3BbIlKARsO+HliYPV4IrCumHRFpFXP3+BPMHgIuBcYB+4AlwOPAGuAs4A3gWnc/+SBe3rriGxORprm75Y3XDXuRFHaR1guFXd+gE0mEwi6SCIVdJBEKu0giFHaRRCjsIolQ2EUSobCLJEJhF0mEwi6SCIVdJBEKu0giWj55hQwMCyI1XZ+cBr2yiyRCYRdJhMIukgiFXSQRCrtIInQ0/jTzzcD4N/77tuAyY+ctC9bqTiwopwy9soskQmEXSYTCLpIIhV0kEQq7SCIUdpFE9OX2TyuALwH73f33s7GlwM3Agexpi939ybob0x1hKvNIpHbNnHDt4RfDtS9fMTZYs//8Tf2mpCWauSPM94H5OeP3ufvs7E/doItIteqG
3d03oe9WiJzymvnMfquZbTGzFWY2urCORKQlGg37/cB0YDbQDdwTeqKZdZhZl5l1NbgtESlAQ2F3933uftTdjwEPABdGntvp7u3u3t5okyLSvIbCbmZtvX69GthWTDsi0ip1r3ozs4eAS4FxZrYXWAJcamazAQf2ALe0sEfph9VPbMkd37ziX4PLXL32u8HaM5FtXRs5vfb4uPzxqw5GVhixYNbkYG3d1rcaW2li6obd3a/PGV7egl5EpIX0DTqRRCjsIolQ2EUSobCLJEJhF0lE3aveCt2YrnpruYb+f678WbBkN14WrA2NrPLIgzfljv/D34RP5IQmywR4/cE7g7WvrVodrK176pXIWvtvQqQW+874LwvtIq6Zq95E5DSgsIskQmEXSYTCLpIIhV0kEQq7SCJ06q0Asf+oaZHa6wX3EeNvvx8ufv3vg6XP/Sh8RVzsdNITgfHHIst8FKk9FKkdi9R+d0r++IpD4WX+dGb4dCNE9uOM6eHaa5EJOP9nQ2R7/dMOdOnUm0jaFHaRRCjsIolQ2EUSobCLJEJH409SdIOxyzB+r+BtxXznkvOCtcE/D3d5WeTA9Gd/HDufMCIwHp4vzs48P7K+sDGBI+4AX+uZmDu+ZGr+OAD/Fj4DwWcv7mNX/XB53sxvwIbwBT4hOhovIgq7SCoUdpFEKOwiiVDYRRKhsIskou6pNzObCvwAmETtmoNOd19mZmOAh6ld67EH+Ct3f7fOugbEqbcB0QTwt5Fa+GZNxYvNq7YvumTshkI9DfUizWn21FsPcIe7zwQuAr5qZucBi4CN7j4D2Jj9LiIDVN2wu3u3u7+QPT4MbAcmAwuAldnTVgJXtapJEWlevz6zm9k0YA7wLDDR3buh9g8C8XeDIlKxundxPc7MRgCPAre7+3tmuR8L8pbrADoaa09EitKnV3YzG0It6KvcfW02vM/M2rJ6G7A/b1l373T3dndvL6JhEWlM3bBb7SV8ObDd3e/tVVoPLMweLwTWFd+eiBSlL6feLgZ+Dmzl/6f7Wkztc/sa4CzgDeBad3+nzroKPes1L1J7usgNSTkmXRKuzZwbqZ0Vro0OXN32buSk4vDIp9srvhRZLnSlHzAuckgrtLnpw8LLcCR3NHbqre5ndnd/Ggh9QP+jesuLyMCgb9CJJEJhF0mEwi6SCIVdJBEKu0giSp1wcqiZjw/UQuMQvuHOrib7KUfkhMfMW8K12EyPsckSXwtM6Lg2MnnhwcfDtajIKa/g9XL5p4xOD58JlyZ9Ply748/yx3dGbjW1c0fucHvXOrreO6AJJ0VSprCLJEJhF0mEwi6SCIVdJBEKu0giSj31Nt7MFwRqUyPLfS4w/uUm+ynF4D8I13qeK68PSYLu9SYiCrtIKhR2kUQo7CKJUNhFElHq0fhRZn5poBa7WdATLehFZKCYHRh/qcH1uY7Gi6RNYRdJhMIukgiFXSQRCrtIIhR2kUTUvSOMmU0FfgBMonb7p053X2ZmS4GbgQPZUxe7+5Oxdf0OEJpZ7VBfO67Qh4HxbZFlYjs4ckMjOc1cF6k1eoqtv/pyy+Ye4A53f8HMRgLPm9mGrHafu9/duvZEpCh9uddbN9CdPT5sZtuBya1uTESK1a/P7GY2DZhD7Q6uALea2RYzW2FmowvuTUQK1Oewm9kI4FHgdnd/D7gfmE7t237dwD2B5TrMrMvMuiKzYItIi/Up7GY2hFrQV7n7WgB33+fuR939GPAAcGHesu7e6e7t7t4euXu1iLRY3bCbmQHLge3ufm+v8bZeT7ua+EFpEalYX47GzwNuALaa2eZsbDFwvZnNBhzYA0TuZVQzdDBMG5dfG/XrPnRSgtzLhQaY8q5TlKI83MAyX5lzVbA2a1b+MfJv/3hNcJm+HI1/mvwMRM+pi8jAom/QiSRCYRdJhMIukgiFXSQRCrtIIkqdcPJsM18cqNU9b1eglZHajQVvK/av6bEG1xm7Sur8BtcpzXsjUju74G2dGRj/
CDiqCSdF0qawiyRCYRdJhMIukgiFXSQRCrtIIvpy1VthBg2GEYGr3pZFrnq7reA+bix4fTGNnl6LuSBS0xVx1bm/xG2FJj+N0Su7SCIUdpFEKOwiiVDYRRKhsIskQmEXSUSpp96GDIFJbfm1H0ZOvd0ZGH+n6Y6KcU2kFtvBjUxCKANXd8Hr+8NI7aPAeGyKZ72yiyRCYRdJhMIukgiFXSQRCrtIIuoejTezYcAm4Izs+Y+4+xIzOwdYDYwBXgBucPePY+safuanmDVreG5tyosfBJf7Sb0mK3bzt1cHa9vW/0ew9vCGVYX38pnA+HuFb0laLXZHtGnD8scHHQkv05dX9iPAF939Amq3Z55vZhcBdwH3ufsM4F3gpj6sS0QqUjfsXnP81upDsj8OfBF4JBtfCYTvQicilevr/dkHZXdw3Q9sAH4FHHL3nuwpe4H820qKyIDQp7C7+1F3nw1MAS4EZuY9LW9ZM+swsy4z6/rNR5paQaQq/Toa7+6HgJ8BFwGjzOz4Ab4pwNuBZTrdvd3d28cOOxXufi5yeqobdjMbb2ajssfDgT8GtgNPAX+ZPW0hsK5VTYpI8/pyIUwbsNLMBlH7x2GNuz9hZq8Aq83sm8CLwPK6G5s4ngl3fCW3duf4x4PLbbtnd+74s3VbL8eSu8Kn3mbPKveGTDrFdvo4EKndtSQ/L7u+c0dwmbphd/ctwJyc8d3UPr+LyClA36ATSYTCLpIIhV0kEQq7SCIUdpFEmHt532ozswPA69mv44CDpW08TH2cSH2c6FTr42x3H59XKDXsJ2zYrMvd2yvZuPpQHwn2obfxIolQ2EUSUWXYOyvcdm/q40Tq40SnTR+VfWYXkXLpbbxIIioJu5nNN7NfmtkuM1tURQ9ZH3vMbKuZbTazrhK3u8LM9pvZtl5jY8xsg5ntzH6OrqiPpWb2VrZPNpvZlSX0MdXMnjKz7Wb2spndlo2Xuk8ifZS6T8xsmJn9wsxeyvr4p2z8HDN7NtsfD5vZ0H6t2N1L/QMMojat1bnAUOAl4Lyy+8h62QOMq2C7XwDmAtt6jX0LWJQ9XgTcVVEfS4Gvl7w/2oC52eORwA7gvLL3SaSPUvcJYMCI7PEQaldzXwSsAa7Lxr8H/F1/1lvFK/uFwC533+21qadXAwsq6KMy7r6J374v5QJqE3dCSRN4Bvoonbt3u/sL2ePD1CZHmUzJ+yTSR6m8pvBJXqsI+2TgzV6/VzlZpQM/NbPnzayjoh6Om+ju3VD7SwdMqLCXW81sS/Y2v+UfJ3ozs2nU5k94lgr3yUl9QMn7pBWTvFYR9ryJ6Ko6JTDP3ecCVwBfNbMvVNTHQHI/MJ3aPQK6gXvK2rCZjQAeBW5398om3cnpo/R94k1M8hpSRdj3AlN7/R6crLLV3P3t7Od+4DGqnXlnn5m1AWQ/91fRhLvvy/6iHQMeoKR9YmZDqAVslbuvzYZL3yd5fVS1T7Jt93uS15Aqwv4cMCM7sjgUuA5YX3YTZvZpMxt5/DFwOfF72bfaemoTd0KFE3geD1fmakrYJ2Zm1OYw3O7u9/YqlbpPQn2UvU9aNslrWUcYTzraeCW1I52/Ar5RUQ/nUjsT8BLwcpl9AA9Rezv4CbV3OjcBY4GNwM7s55iK+vghsBXYQi1sbSX0cTG1t6RbgM3ZnyvL3ieRPkrdJ8D51CZx3ULtH5Z/7PV39hfALuDfgTP6s159g04kEfoGnUgiFHaRRCjsIolQ2EUSobCLJEJhF0mEwi6SCIVdJBH/B1qcJzRYdiEMAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Show one transformed training image; permute moves the channel axis last, as imshow expects\n",
    "img_t, _ = transformed_cifar10[99]\n",
    "plt.imshow(img_t.permute(1, 2, 0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map the original CIFAR-10 labels (0 = airplane, 2 = bird) onto 0/1 for the binary task\n",
    "label_map = {0: 0, 2: 1}\n",
    "class_names = ['airplane', 'bird']\n",
    "\n",
    "# Keep only airplanes and birds from the tensor-transformed training set\n",
    "cifar2 = [(img, label_map[label]) for img, label in transformed_cifar10\n",
    "          if label in label_map]\n",
    "\n",
    "# As loaded at the top of the notebook, cifar10_val has no transform (it yields PIL images),\n",
    "# so apply the training transform to get tensors normalized the same way as cifar2\n",
    "val_transform = transformed_cifar10.transform\n",
    "cifar2_val = [(val_transform(img), label_map[label]) for img, label in cifar10_val\n",
    "              if label in label_map]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "\n",
    "# Fully connected classifier: flattened 3x32x32 image -> 512 hidden units -> one score per class\n",
    "n_out = 2\n",
    "model = nn.Sequential(\n",
    "    nn.Linear(3 * 32 * 32, 512),\n",
    "    nn.Tanh(),\n",
    "    nn.Linear(512, n_out),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.0900, 0.2447, 0.6652],\n",
       "        [0.0900, 0.2447, 0.6652]])"
      ]
     },
     "execution_count": 68,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Softmax by hand: exponentiate, then normalize each row so it sums to 1.\n",
    "# Subtracting the row max first avoids overflow in exp() without changing the result.\n",
    "def softmax_manual(x, dim=1):\n",
    "    exps = torch.exp(x - x.max(dim=dim, keepdim=True).values)\n",
    "    return exps / exps.sum(dim=dim, keepdim=True)\n",
    "\n",
    "# nn.Softmax must be told which dimension to normalize over (dim=1: across classes, per sample)\n",
    "softmax = nn.Softmax(dim=1)\n",
    "x = torch.tensor([[1.0, 2.0, 3.0],\n",
    "                  [1.0, 2.0, 3.0]])\n",
    "\n",
    "assert torch.allclose(softmax_manual(x), softmax(x))\n",
    "softmax(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same network, now ending in Softmax so each output row is a probability distribution over the classes\n",
    "model = nn.Sequential(\n",
    "    nn.Linear(3072, 512),\n",
    "    nn.Tanh(),\n",
    "    nn.Linear(512, n_out),\n",
    "    nn.Softmax(dim=1),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7f662cee2dd0>"
      ]
     },
     "execution_count": 80,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD5CAYAAADhukOtAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAZQElEQVR4nO2dfZzVdZXH30ceRAVDBYVAAxRLSwVDrVzclPDp5a4PZelWa+WKtbpbu+2WtbXZ42av0rTMxHTVXj6WVlZaEVlasgo+MeKUgKIiIw8hAioicvaPeymk3zkzc2fmztD383695jV3zmfO73fu795zf/f+zj3na+6OEOKvn216OwAhRHNQsgtRCEp2IQpByS5EISjZhSgEJbsQhdC/K85mdjRwIdAP+I67fzn7/yHDzHcZU609/kjiuG21eZtBscsA6xdq220f3+2dBg8LtR3ZtdLeP3nNXMvqUFu8ZkGoDR4Sl0RfFSowILCvS3yCwwvkZ4OsaPtyYN8h8ekJVgX29YnPC+FRhOxer16zIdTWv5Bs8vlEi3g2sL8MvtGtSrJG6+xm1g94BJgKLAZmA6e6+8ORz5hJ5p+eU63905HJzsZWm3fcJ07a4f3jlJiw/y6hduLk00Ntqp1Vad81eQrfxS9D7WMzjwu1N095MdRiLxge2OcnPsHhBWBwomUvIGsD+8GJT6NsTLQfB/ZFiU8ro0JtA3FC/3Lm0lB7vDXZ4f2JFnFrYF8B/lJ1snflbfzBwAJ3f9Td1wPXA8d3YXtCiB6kK8k+Cnhys78X121CiD5IV5K96q3CX3wmMLNpZjbHzOasWd6FvQkhukRXkn0xsPtmf48Glmz5T+4+3d0nufukIdEHSiFEj9OVZJ8NjDezsWY2EDgFuKV7whJCdDcNl97cfYOZnQ38nFrp7Qp3n5f59CO5uvv2xPGD1ebV+8RXRlfv98dQe+zpWFtwz6WxdkT14Tr1wMNDn7ckhbJvTRkaavOJr+wGBQ0gvkI+OvGJjyKsSLQ4+kavuo9LtP1DZQ6zQ+1/fvRUpX3w7pXmGsPiez3zorhKMnBiss04xLg+mBE90ElxrUt1dne/lbgIIIToQ+gbdEIUgpJdiEJQsgtRCEp2IQpByS5EIXTpanxnWQp8I9CmnBn7zQzqdQdMzGodcXntwU88EWu3PBprx3y00t5yblwWmnrw3FDLKi5JQx+LEy2q8ByT+IxItDcn2o7slqjR8c8KfXFZ6y5OCLUZP4rLm3efcFW1cFIcxSHfjONgv1haf0+s8ViiRVl4e+LTADqzC1EISnYhCkHJLkQhKNmFKAQluxCF0NSr8c89Bb/7VLW25xdiv7PeXW2/+KZknk82BuigRMv69m6rNs86K77i/nfJ5rKr8ecnWsbUwJ5dA8/GUu2YziOpnskH8MHg3k1l79DnoGQa3vJkQFbL7u8JNQiuxicH5HUjY23V5Fj7Q3bFPdlms7pLdGYXohCU7EIUgpJdiEJQsgtRCEp2IQpByS5EITS19MbTwBerpYWJ28V/HwhZt0gyIG3PZPWZhVnJ7tpq85JkUNv77ki2l8S/a4NLp0SLV0UlOYC90wWl4mW0fvqXw4T/xLrgARhL3Lw0i3iW3ynZkMIDYym852NmhB4zwjWNYMnFya6yDqUnEy1aPqeb0ZldiEJQsgtRCEp2IQpByS5EISjZhSgEJbsQhdCl0puZLQLWAC8DG9x9UsMbuybRonJYMg+MD8TShmTpn0O+GWt3R0vuzE/iyEg67NZeEGv/vEesRQvl3pmE0cKzoTYm0R5ItnlQMJ/uGX4d+lzHyfEGq9YM7hDvqDZv2CH0WHLxD+PNZZ1tOyVaH1jBuDvq7Ie7e7YkmBCiD6C38UIUQleT3YFfmNm9ZjatOwISQvQMXX0bf6i7LzGzXYEZZvZ7d3/FF0TrLwJ6IRCil+nSmd3dl9R/LwN+QMWy3O4+3d0n
deninRCiyzSc7Ga2g5kN2XQbOBJ4qLsCE0J0L+bujTmajaN2Nofax4Fr3T3oafuTT2M7+3xgvzrxydYtSrrNXntprJ0R2N+Q7GpFMrBx2j1PhdrzLfE2Dzg91qIGqqyr8NBE+1CiRR12ACOprg+28HLo8565SU3xgGTSI19JtAbI7tgRiRbPxIRZiRZ1xDXYDefulYXKhj+zu/ujwAGN+gshmotKb0IUgpJdiEJQsgtRCEp2IQpByS5EITRcemtoZ42W3o4J7EMSn6wT7ZlEi4ZbAhwW2JO1496eVJOSmZhcfl8iZsMLg0GVr0/WGkt6zdKyYtI8yJjA3souoc9bZ+4bb/BtUcshwOxYisph+ySbe2eiZQNJ2xItWCewJ4hKbzqzC1EISnYhCkHJLkQhKNmFKAQluxCF0NzlnzKySBYE9n9scF83JtrNiRatGDQmdrnprGR7yRJP28SrJDEiWe5ofGB/bxLGXomWkY1Vi7QN/DF2mrF3Y4FcmVyNDzj+tFgbkfhdemYixitK9Ql0ZheiEJTsQhSCkl2IQlCyC1EISnYhCkHJLkQh9J3S24ZEWxPYs6aEjGQGXXpEpgb2rCHn6US7Kgnjw7E2eUCyzYBsyZ7obrVHI4c/a2fhvLhJhmSW37dPOzrU9uJnlfbseCxKtNQxew73AXRmF6IQlOxCFIKSXYhCULILUQhKdiEKQckuRCG0W3ozsyuA44Bl7v6Gum1n4AZq/V6LgHe6ezbZrWtE5atrE5+kayxsDYN8dt1OgT3rvss67JLlfdYnc+YWjYu10YF9ELuFPrexNNSyOXk/TrSoUTEn21u8ftKYZAZd9JBl8W0g7r474MOPhNqD+yUb/WyiRc/VrEQ8PLD/JnbpyJn9SmDLQuY5wEx3Hw/MrP8thOjDtJvs9fXWV25hPp4/fyXkKuCEbo5LCNHNNPqZfTd3bwOo/961+0ISQvQEPf51WTObBkzr6f0IIXIaPbMvNbORAPXfy6J/dPfp7j7J3Sc1uC8hRDfQaLLfAmya4nUa8KPuCUcI0VO0u/yTmV0HvBUYBiwFPgP8kFpRaQ/gCeBkd9/yIl7Vtpq31lRGNlEw61KLSivZh5TtEi1pN8uWjTqZHZKNVq93NCwpva1gbqj9X7Knr7+QiF8L7FcnPvO/E2v7xJM73/Xwi6E2ObC/jv1Dn4O4MNQ2cGmo9Sd+0GazZ6g9HRz/tcRlvt/7wkr7tQctZumcFyuXf2r3M7u7nxpIU9rzFUL0HfQNOiEKQckuRCEo2YUoBCW7EIWgZBeiEPrOwMm+QtZp1BLYP5H4fDWWTkvKa1G3FsAi4sGMw4IST//koX4g2dfXs29Q/CrRosGMWVchj8bS1PiBWUVceosqqYOTcmN/PhZqI3ki1PYO640whXeHWjS6cyVPhR4729sq7XcSf3dNZ3YhCkHJLkQhKNmFKAQluxCFoGQXohCU7EIUwtZdesuiz9bdWpVo6WJkAcngyLzUFNfe+nNisrt4suHooJtrBX8Mfe68L5Rg1oxY6/Z1z74UKofstG2o/VuyxSjEbODknckAy+zp8TneE2rjkm3Cqyuts3ku9DiKw5PtVaMzuxCFoGQXohCU7EIUgpJdiEJQsgtRCFv31fiGrvjS2BX3RqkeCVeTnomvPr+w7pRQ22tYv3ijwSPaP6kYnHHglgv+/Jn3Hxj7LYiHCtPy8+qmlp/eeF68QX4YKpM3xM0uR/1p9ulf8rU/rWXyShpZWQmgLdEeS7TRyVy76KHJ5v/NeeHKSnvbxniIos7sQhSCkl2IQlCyC1EISnYhCkHJLkQhKNmFKIR2S29mdgVwHLDM3d9Qt50LnAEsr//bJ9391p4Kss+TlNd2Xvvfofa9i+MlgYYPjctrq8bH+1sbVF4WzI9LV2PGx00mg4bG+5p8RLxS94i3VGu3nfTh0GfjzXHpbVZS13o4KK8BTAjsYxkV+jyZzH4bkqTMBuLH7H+TOXmjA/sxoQcM2q664ena
bZ4NfTpyZr8SqCrEXuDuE+o/5Sa6EFsJ7Sa7u98BtLtooxCib9OVz+xnm9lcM7vCzLLJx0KIPkCjyX4JsCe1j0RtxAv0YmbTzGyOmc1pcF9CiG6goWR396Xu/rK7bwQuAw5O/ne6u09y93h6vRCix2ko2c1s5GZ/ngg81D3hCCF6io6U3q4D3goMM7PFwGeAt5rZBMCBRcCZPRhjyKvHxCWjsYeFbzbovy6+27+58fbOBzL230Np5WOTY7/lj4fSsvE7hFrb03HZaGXLI9XC3Hmhz7y18awz1salnJsOmhhqAydWlxU33pzMtEv4XbT0FvCtxG9IYF+elNf2SbY3NWm1HJpo2djDaKLgwWQdgh+stG7H34Ye7Sa7u59aYb68PT8hRN9C36ATohCU7EIUgpJdiEJQsgtRCEp2IQqhqQMnR+3yav7l+OqSwaDD4jLOoIn7VtoPHzsu9Bkc1VxIm9Q4acTZoTbzouurhajcBdDyRBJIXF5jRbwm08rluyV+1YMeSUpNsEuiJWtD3Rl39K2/M9rmq5J9JSSlt2x+6M8C+8IvJE7ZVMlkAOeZp8fab5NNRvG/JRnAGQcSl1F1ZheiEJTsQhSCkl2IQlCyC1EISnYhCkHJLkQhNLX0NmLMSD5++aebuctO07oiWRSNPwb2nzS2s2xXrVk57B2xNPTN1fZVSZmPpDyYrOeWEx2ryN442Zpo4RM8e+ZnbXRJS9ylyXDOsLUNmDe22v7jAbNCny8ytdK+OglBZ3YhCkHJLkQhKNmFKAQluxCFoGQXohCaejV+a2BFS9Z80Eyyq9aXxtKqaA5aPB8NggafvkTyTJ33o8TvsGrzG8+JXe59MtlesLwWAJnfsZ33u3dx7HJJcL+y2onO7EIUgpJdiEJQsgtRCEp2IQpByS5EISjZhSiEjiz/tDtwNTAC2AhMd/cLzWxn4AZgDLUloN7p7s/0XKidYz1LQm1gUtbq3xIvd7S+SxE1i7/SxXqmJdruiRZUHFuSZ+prk/l0Q9fEWmtSlhu0XawtC+Ylvj4ey8i6F6rtvjH26ciZfQPwUXffB3gTcJaZ7QucA8x09/HAzPrfQog+SrvJ7u5t7n5f/fYaoBUYBRwPXFX/t6uAE3oqSCFE1+nUZ3YzGwNMBO4GdnP3Nqi9IAC7dndwQojuo8PJbmaDgZuAj7h71iO/pd80M5tjZnOWL1/eSIxCiG6gQ8luZgOoJfo17n5z3bzUzEbW9ZEEX8t19+nuPsndJw0fPrw7YhZCNEC7yW5mRu0Sb6u7n7+ZdAtwWv32aUDWjiCE6GU60vV2KPBeoMXMHqjbPgl8GbjRzE6nNsTs5J4JEVYG9rVESx3BKv9lqI1gaag939GgRFM55OJYu/vnsbZjsEpS9sRvS0byvX+P/UPtxD3mhlrWc/ipwO1NU2KfaKTdw8npu91kd/ffAhbISThCiL6EvkEnRCEo2YUoBCW7EIWgZBeiEJTsQhTCVjFwcufAPphxoc/Tv3oq1G5bcWeobT84juP5bLkm0XWOadDv/lja6ahqe9aeeeIesXYy24baoGSbtyfaoUdU27NmvuvuqravTJ6jOrMLUQhKdiEKQckuRCEo2YUoBCW7EIWgZBeiELaK0lsjDBs7KtTGHBFP8pvYEpflfvfF6t6lN348juPeWMprNfMT7dpso03kzYk2q4HtfSqWpvKqUJtwTvw0XhAMF53t8b7WRW1fwPnMDrWscpgs28bkYH/LkxiffKzavj6ZiqozuxCFoGQXohCU7EIUgpJdiEJQsgtRCE29Gr+ReMbb2mA5G4ChwdI5/Xku9Bk3Lm6SWbvmjlCLrrhntF6aiMcmWjZZe3ynw2g+qxrwGZ1oydJKXzgiXpaLfZJtBlf4t0kanm4IrnQDkDSa/OwtsXZ0ssnJgX1VUhVYdVK1/bavxT46swtRCEp2IQpByS5EISjZhSgEJbsQhaBkF6IQ2i29mdnuwNXACGrVs+nufqGZnQucwZ8LSJ90
91uzbW0DbB9oK5IyzsCg9LaMn4Q+37vhlFA7O5bSV7+Ngf35rATVaNPKjAb9mknnq5QQPJYAvC/Rnk60bMDbwdXmjdkQuqyJ552xtPCrsXZx3F/FxGCVxPcTN3O1bFc9Y7FfV5Z/ovaQftTd7zOzIcC9ZrbpqXiBuyd3UQjRV+jIWm9tQFv99hoza4XkJUcI0Sfp1Gd2MxsDTATurpvONrO5ZnaFme3UzbEJIbqRDie7mQ0GbgI+4u6rgUuAPYEJ1M78lV/UM7NpZjbHzOYsX559P1QI0ZN0KNnNbAC1RL/G3W8GcPel7v6yu28ELiO4FOLu0919krtPGj58eHfFLYToJO0mu5kZcDnQ6u7nb2Yfudm/nQg81P3hCSG6i45cjT8UeC/QYmYP1G2fBE41swmAA4uAM9vb0FrWcRetlVrbkwtDv5aHq+3fvT2uod3wi/aiqSYqr/Up/jXRLurmfX0mlgbuF2vr3xEI2Wy9RknKYWEn3YrE58ZEy4rLDS4P9o2g43O/oLwGcNncavtLSfdoR67G/xaoarZLa+pCiL6FvkEnRCEo2YUoBCW7EIWgZBeiEJTsQhRCUwdOrn5pBTParqzUWu66JvRbO7+6BHH7/cnOsqGBWzlT3hVrM7u79HZVLK3Puv0OCuzx6kmNMzbRdg/sAxrcV4PltWyA6IPB8/gHyQDLYcH9Wj4w9tGZXYhCULILUQhKdiEKQckuRCEo2YUoBCW7EIXQ1NLbi8+tZP491SW2/kPiDp9hwdDAyckaXzP/M9Z2jCVWJ1rEUcfE2s9va2CDwJSodAVMnBhrM6OOuEZLcosSbWiiRaWmbEhlVkrNaGTw5eGJ9g+J1ugA0azbL1gr8MvJ2ncHHFVtf6Zf7KMzuxCFoGQXohCU7EIUgpJdiEJQsgtRCEp2IQqhuaW3Z19iwa3VJbbBUXcSsDiIckRSnjr+h7G2Ihk2uCqJY90d1fZZjZZjEmYm3WEzz0kcg2nd2387dnn+3GR7yTDH17871vYKyqWDkl39IFjzDGB9NvFwdKJFj/W6xCcp6fYIUckxOVgtQaffxuR+6cwuRCEo2YUoBCW7EIWgZBeiEJTsQhRCu1fjzWwQcAewbf3/v+/unzGzscD1wM7AfcB73X19tq1dBsH7gwaJ3ydXwaN+iw3JItEjDoy1p++LtdbkyvrGynVqe4FsDtrV1ebnk3lxb/x8rD2ZLLw777xEO7Lavn3SxPOl42OtNdF+6bH2eLSUU1vsw8hESypADc+ne6bzLv2DK/UvJafvjpzZXwSOcPcDqC3PfLSZvQk4D7jA3cdTC/f0zoUrhGgm7Sa719j0mjWg/uPAEcD36/argBN6JEIhRLfQ0fXZ+9VXcF0GzAAWAqvcfVMn8WJgVM+EKIToDjqU7O7+srtPoPZdpYOp/o5R5ScnM5tmZnPMbM7aRj/TCCG6TKeuxrv7KuDXwJuAoWa26QLfaGBJ4DPd3Se5+6TBg7sSqhCiK7Sb7GY23MyG1m9vB7wNaAVuB95R/7fTgOSbzUKI3sbck7oFYGb7U7sA14/ai8ON7v45MxvHn0tv9wPvcfcXs21NGGn+i+Ca/eKPbxv6Xfe16s1emzRHrEqaGYYmZZdVM2Lt+VjqfoYlWlb+aXDmXcjkREuaLnYMSm+rk2W5XvOBWDtuSqwFvT8AXPRotX3lhYlTUrYlKUWmS45lQd4c2LPZelFj0zTw37tVSe3W2d19LvAX1VF3f5Ta53chxFaAvkEnRCEo2YUoBCW7EIWgZBeiEJTsQhRCu6W3bt2Z2XLg8fqfw4gnhDUTxfFKFMcr2drieI27Vxb6mprsr9ix2Rx3n9QrO1cciqPAOPQ2XohCULILUQi9mezTe3Hfm6M4XonieCV/NXH02md2IURz0dt4IQqhV5LdzI42sz+Y2QIzyxYz6uk4FplZi5k9YGZzmrjfK8xsmZk9tJltZzObYWbz67+TcZo9Gse5ZvZU/Zg8YGbHNiGO3c3sdjNr
NbN5Zvbhur2pxySJo6nHxMwGmdk9ZvZgPY7P1u1jzezu+vG4wcwGdmrD7t7UH2qtsguBccBA4EFg32bHUY9lETCsF/Z7GLVGyoc2s30FOKd++xzgvF6K41zgP5p8PEYCB9ZvDwEeAfZt9jFJ4mjqMQEMGFy/PQC4m9rAmBuBU+r2bwMf6sx2e+PMfjCwwN0f9dro6euBZFDwXx/ufgewcgvz8dTmBkCTBngGcTQdd29z9/vqt9dQG44yiiYfkySOpuI1un3Ia28k+yjgyc3+7s1hlQ78wszuNbNpvRTDJnZz9zaoPemAXXsxlrPNbG79bX6Pf5zYHDMbQ21+wt304jHZIg5o8jHpiSGvvZHsVVM0eqskcKi7HwgcA5xlZof1Uhx9iUuAPamtEdAGNG1pDDMbDNwEfMTdVzdrvx2Io+nHxLsw5DWiN5J9MbD5+i/hsMqext2X1H8vA35A707eWWpmIwHqv5f1RhDuvrT+RNsIXEaTjomZDaCWYNe4+6ZBTU0/JlVx9NYxqe+700NeI3oj2WcD4+tXFgcCpwC3NDsIM9vBzIZsug0cCTyUe/Uot1Ab3Am9OMBzU3LVOZEmHBMzM+ByoNXdz99MauoxieJo9jHpsSGvzbrCuMXVxmOpXelcCPxXL8Uwjlol4EFgXjPjAK6j9nbwJWrvdE4HdgFmAvPrv3fupTi+C7QAc6kl28gmxPE31N6SzgUeqP8c2+xjksTR1GMC7E9tiOtcai8s/73Zc/YeYAHwPWDbzmxX36ATohD0DTohCkHJLkQhKNmFKAQluxCFoGQXohCU7EIUgpJdiEJQsgtRCP8P1U8WvRMTZVkAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Look at one training sample before running it through the (untrained) model\n",
    "img, _ = cifar2[0]\n",
    "\n",
    "# The tensor is normalized (it contains negative values), so imshow would clip it to [0, 1];\n",
    "# min-max rescale a copy just for display, with channels moved last as imshow expects\n",
    "img_display = (img - img.min()) / (img.max() - img.min())\n",
    "plt.imshow(img_display.permute(1, 2, 0));"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.3347, 0.6653]], grad_fn=<SoftmaxBackward>)"
      ]
     },
     "execution_count": 81,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build a batch of size 1: flatten the image and add a leading batch dimension\n",
    "img_batch = img.view(1, -1)\n",
    "# The model is still untrained, so these class probabilities are not yet meaningful\n",
    "out = model(img_batch)\n",
    "out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([1])"
      ]
     },
     "execution_count": 82,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Predicted class = index of the largest probability in each row\n",
    "index = out.max(dim=1).indices\n",
    "index"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 如何选取loss函数\n",
    "\n",
    "我们的目的是最大化与正确类别对应的概率。比如[1,0]代表飞机，[0,1]代表鸟：如果样本是飞机，训练应使输出中第一个元素（索引为[0]）尽可能大，我们把这个元素称为该类别的可能性（也就是概率）。训练的目标是最小化loss，因此当正确类别的概率很小时（样本实际上是飞机，模型却认为它不太可能是飞机），loss应该很大，表示模型还远未达到要求；随着训练进行，loss变小，说明正确类别的概率在变大。\n",
    "\n",
    "根据这个思路，loss应该满足：正确类别的概率很小时loss很大，概率很大时loss很小。负对数似然 $\\text{loss} = -\\log p_y$（$p_y$ 为正确类别 $y$ 的预测概率）正好具有这个性质。\n",
    "\n",
    "PyTorch中的`nn.NLLLoss`就是这样一个函数，但它要求输入已经是对数概率（它本身不再取对数），因此模型最后一层应该用`nn.LogSoftmax`代替`nn.Softmax`。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace the final Softmax with LogSoftmax, because nn.NLLLoss expects log-probabilities\n",
    "model = nn.Sequential(\n",
    "    nn.Linear(3072, 512),\n",
    "    nn.Tanh(),\n",
    "    nn.Linear(512, n_out),\n",
    "    nn.LogSoftmax(dim=1),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Negative log-likelihood loss, averaged over the batch (the default reduction)\n",
    "loss = nn.NLLLoss(reduction='mean')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.8193, grad_fn=<NllLossBackward>)"
      ]
     },
     "execution_count": 88,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Try the loss on a single sample\n",
    "img, label = cifar2[0]\n",
    "\n",
    "out = model(img.view(-1).unsqueeze(0))  # batch containing one flattened image\n",
    "\n",
    "# NLLLoss arguments: (1) the model's LogSoftmax output, i.e. log-probabilities of shape (batch, n_classes);\n",
    "# (2) the ground-truth class indices of shape (batch,), here airplane = 0, bird = 1\n",
    "loss(out, torch.tensor([label]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.5812, -0.8193]], grad_fn=<LogSoftmaxBackward>)"
      ]
     },
     "execution_count": 89,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Log-probabilities for [airplane, bird]; exponentiating a row gives probabilities that sum to 1\n",
    "out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 90,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Ground-truth class of the sample (0 = airplane, 1 = bird); the NLL loss above equals -out[0, label]\n",
    "label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[[ 0.6139, -0.3228,  ..., -0.2752, -0.5451],\n",
       "          [ 0.6615, -0.1482,  ..., -0.3228, -0.5768],\n",
       "          ...,\n",
       "          [ 0.5980,  0.4393,  ..., -0.4340,  0.0265],\n",
       "          [ 0.9156,  0.8044,  ..., -0.5451, -0.0529]],\n",
       " \n",
       "         [[ 1.3369,  0.2740,  ...,  0.3867,  0.0968],\n",
       "          [ 1.4497,  0.5961,  ...,  0.3062,  0.0646],\n",
       "          ...,\n",
       "          [ 0.5478,  0.6605,  ...,  0.4028,  0.8860],\n",
       "          [ 0.4834,  0.9504,  ...,  0.1613,  0.7572]],\n",
       " \n",
       "         [[-0.4487, -0.7935,  ..., -0.6736, -0.8535],\n",
       "          [-0.4487, -0.9734,  ..., -0.6286, -0.8535],\n",
       "          ...,\n",
       "          [-0.4337, -0.4787,  ..., -1.3032, -0.9884],\n",
       "          [-0.1789,  0.0310,  ..., -1.3182, -1.0484]]]),\n",
       " 1)"
      ]
     },
     "execution_count": 91,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
